import numpy as np
from Utility import *

# Get Match data by searching distance 
def _sample_matched_covariates(left_covariate, match_distance):
    """Draw one distance-matched partner for every row of ``left_covariate``.

    Each partner is rejection-sampled: candidates are drawn with
    ``np.random.uniform`` (i.e. from [0, 1) in every dimension) until one lies
    strictly closer than ``match_distance`` (Euclidean) to the row. Rows are
    processed in order, one ``(1, feat)`` draw per attempt, so the global NumPy
    RNG stream is consumed exactly as a per-row loop would consume it.
    """
    n_rows, feat = left_covariate.shape
    right_covariate = np.zeros((n_rows, feat))
    for i, left in enumerate(left_covariate):
        while True:
            candidate = np.random.uniform(size=(1, feat))[0]
            # Vector (Euclidean) norm of the 1-D difference.
            if np.linalg.norm(candidate - left) < match_distance:
                right_covariate[i] = candidate
                break
    return right_covariate


def _synthetic_response(covariate, noise_var):
    """Return ``x0 + 2*x1 - x0*x1`` plus Gaussian noise of variance ``noise_var``."""
    x0, x1 = covariate[:, 0], covariate[:, 1]
    noise = np.random.normal(size=len(covariate)) * np.sqrt(noise_var)
    return x0 + 2 * x1 - x0 * x1 + noise


def GetMatchDataByDistance(Args):
    """Build synthetic matched-pair data using distance-constrained matching.

    The function produces ``Args.S + Args.TestS`` pairs:

    - The left covariate is drawn uniformly from [0, 1)^2.
    - The right covariate is rejection-sampled from the same box until it lies
      within ``Args.Matchdistance`` (Euclidean, strict) of the left one.
    - Each pair is randomly assigned which side is the treatment.
    - Both sides get a noisy response.
    - An enrichment label is computed for the pair.

    All randomness comes from the global ``np.random`` state.

    Args:
        Args: namespace-like object providing
            S (int): number of matched-pair experiment rows returned first.
            TestS (int): number of enrichment-classifier test rows.
            Matchdistance (float): strict upper bound on the left/right
                covariate distance; must be > 0.
            noise_var (float): variance of the additive Gaussian noise.
            DataType (str): ``'Syn2'`` enables the +1 treatment bump and the
                region-based test labels; any other value disables both.
            Sep (float): margin used in the region test ``x0 + Sep < x1``.
            RespondingThres (float): a training pair is labelled 1 when its
                treatment-minus-control response exceeds this value.

    Returns:
        (MatchPairExpData, EnrichmentClsTestingData): arrays of shape
        ``(S, 9)`` and ``(TestS, 9)``. Both are views into one array. Columns:

        - 0-1: left covariate
        - 2-3: right covariate
        - 4: left response
        - 5: right response
        - 6: assignment (1 = left is treatment)
        - 7: ``1 - assignment``
        - 8: enrichment label

    Raises:
        ValueError: if ``Args.Matchdistance`` is not positive (the matching
            loop could never terminate).
    """
    feat = 2  # dimension of each covariate vector
    n_total = Args.S + Args.TestS
    if not Args.Matchdistance > 0:  # also rejects NaN
        raise ValueError(
            "Args.Matchdistance must be positive, got %r" % (Args.Matchdistance,))

    left_covariate = np.random.uniform(size=(n_total, feat))
    right_covariate = _sample_matched_covariates(left_covariate, Args.Matchdistance)

    # 0 -> left is control / right is treatment; 1 -> the reverse.
    assignment = np.random.randint(0, 2, size=n_total)

    left_resp, right_resp = 2 * feat, 2 * feat + 1
    data = np.zeros((n_total, feat * 2 + 5))
    data[:, :feat] = left_covariate
    data[:, feat:2 * feat] = right_covariate
    data[:, -3] = assignment
    data[:, -2] = 1 - assignment
    # Left noise is drawn before right noise, matching the RNG order callers rely on.
    data[:, left_resp] = _synthetic_response(left_covariate, Args.noise_var)
    data[:, right_resp] = _synthetic_response(right_covariate, Args.noise_var)

    left_treated = assignment == 1
    if Args.DataType == 'Syn2':
        # The treated unit gets +1 response when its own covariate is in the
        # region x0 + Sep < x1.
        right_bump = ~left_treated & (right_covariate[:, 0] + Args.Sep < right_covariate[:, 1])
        left_bump = left_treated & (left_covariate[:, 0] + Args.Sep < left_covariate[:, 1])
        data[right_bump, right_resp] += 1
        data[left_bump, left_resp] += 1

    # Treatment-minus-control response difference for every pair.
    effect = np.where(left_treated,
                      data[:, left_resp] - data[:, right_resp],
                      data[:, right_resp] - data[:, left_resp])
    data[:, -1] = effect > Args.RespondingThres

    # Test rows are relabelled with the ground-truth region of their left
    # covariate ('Syn2' only; all zeros otherwise). test_rows is a view.
    test_rows = data[Args.S:]
    if Args.DataType == 'Syn2':
        test_rows[:, -1] = test_rows[:, 0] + Args.Sep < test_rows[:, 1]
    else:
        test_rows[:, -1] = 0

    return data[:Args.S], data[Args.S:]

